from callbacks import GiveModelToEnvCallback, FixPolicyActionsCallback, CustomCheckpointCallback
from stable_baselines3.common.callbacks import CallbackList

from ai_collusion.gym_envs.envs.bertrand_competition import RestartExplorationRateWrapperNonEpisodic, BertrandCompetitionDiscreteEnv, EvaluationAfterJohnsonConvergence

from env_setups import get_standard_env, get_no_Stackelberg_env, get_johnson_env
from rl_trainer_setup import get_custom_training_algorithm

import os
import utils
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd


def get_latest_checkpoint(folder_path):
    """Return the path (without extension) of the newest checkpoint in a folder.

    A checkpoint is any ``.zip`` file whose name, split on ``_``, carries an
    integer step count as its third token (e.g. ``rl_model_5000_steps.zip``).
    The checkpoint with the largest step count wins. Entries that do not
    follow this pattern (sub-directories, log files, unrelated archives) are
    ignored instead of aborting the scan.

    Args:
        folder_path: Directory to scan.

    Returns:
        ``folder_path + "/" + <file name without ".zip">`` for the checkpoint
        with the highest step count, or ``None`` when the folder contains no
        matching file.

    Raises:
        OSError: Propagated from ``os.listdir`` when the folder cannot be read.
    """
    max_steps = -1
    latest_checkpoint = None
    for file_name in os.listdir(folder_path):
        stem, extension = os.path.splitext(file_name)
        if extension != '.zip':
            continue
        try:
            steps = int(stem.split('_')[2])
        except (IndexError, ValueError):
            # Not named like a checkpoint: too few tokens or non-numeric count.
            continue
        if steps > max_steps:
            max_steps = steps
            latest_checkpoint = folder_path + "/" + stem
    return latest_checkpoint


def train_run(config_dict):
    """Train a model on the standard environment and checkpoint it.

    The experiment name is assembled from the configuration values and used
    as the sub-folder of ``logs/`` (next to this file) in which checkpoints
    are saved. The resulting folder is also written back into
    ``config_dict["log_folder"]`` before the environment is created.

    Args:
        config_dict: Experiment configuration. Besides the keys that make up
            the experiment name, ``'experiment_type'`` is read, and the dict
            is passed to ``get_standard_env``.

    Returns:
        None.
    """
    reward_steps = config_dict['tot_num_reward_steps']
    eq_steps = config_dict['tot_num_eq_steps']
    excluded_fraction = config_dict['frac_excluded_eq_steps']

    name_keys = (
        'gamma',
        'dp_type',
        'max_steps',
        'algorithm',
        'seed',
        'tot_num_reward_steps',
        'tot_num_eq_steps',
        'bbox_state_space_type',
        'frac_excluded_eq_steps',
        'reward_step_random_price_prob',
        'q_restart_rate',
        'marginal_cost',
        'critic_obs',
    )
    exp_name = "exp.%s.%s.%s.%s.%s.%s.%s.%s.%s.%s.%s.%s.%s" % tuple(
        config_dict[key] for key in name_keys
    )
    if config_dict['experiment_type'] == 'train_with_cost_perturbation':
        exp_name += "_cost_perturbation"

    logs_root = os.path.abspath(os.path.dirname(__file__)) + "/logs/"
    log_folder = os.path.join(logs_root, exp_name)
    config_dict["log_folder"] = log_folder

    env = get_standard_env(config_dict)

    callbacks = [
        GiveModelToEnvCallback(),
        FixPolicyActionsCallback(),
        CustomCheckpointCallback(save_freq=10000000, save_path=log_folder),
    ]

    # Rollout length: the equilibrium steps that are not excluded, plus all
    # reward steps.
    excluded_eq_steps = int(excluded_fraction * eq_steps)
    rollout_length = eq_steps - excluded_eq_steps + reward_steps
    mod = get_custom_training_algorithm(
        config_dict['algorithm'], env, n_steps=rollout_length
    )

    mod.learn(
        total_timesteps=config_dict['max_steps'],
        callback=CallbackList(callbacks),
    )


def train_no_Stack_run(config_dict):
    """Train a model on the no-Stackelberg environment and checkpoint it.

    The experiment name (prefix ``exp_no_stack``) is assembled from the
    configuration values and used as the sub-folder of ``logs/`` (next to
    this file) where checkpoints are saved. The folder is also stored in
    ``config_dict["log_folder"]`` before the environment is created.

    Args:
        config_dict: Experiment configuration; passed to
            ``get_no_Stackelberg_env``.

    Returns:
        None.
    """
    name_keys = (
        'gamma',
        'dp_type',
        'max_steps',
        'algorithm',
        'seed',
        'tot_num_reward_steps',
        'tot_num_eq_steps',
        'bbox_state_space_type',
        'frac_excluded_eq_steps',
        'reward_step_random_price_prob',
        'q_restart_rate',
        'marginal_cost',
    )
    exp_name = "exp_no_stack.%s.%s.%s.%s.%s.%s.%s.%s.%s.%s.%s.%s" % tuple(
        config_dict[key] for key in name_keys
    )

    logs_root = os.path.abspath(os.path.dirname(__file__)) + "/logs/"
    log_folder = os.path.join(logs_root, exp_name)
    config_dict["log_folder"] = log_folder

    base_env = get_no_Stackelberg_env(config_dict)

    # The wrapper receives the restart rate and the combined length of the
    # reward and equilibrium phases as the expected gap between restarts.
    env = RestartExplorationRateWrapperNonEpisodic(
        base_env,
        config_dict['q_restart_rate'],
        expected_num_steps_between_restarts=(
            config_dict['tot_num_reward_steps'] + config_dict['tot_num_eq_steps']
        ),
    )

    checkpoint_callback = CustomCheckpointCallback(
        save_freq=10000000, save_path=log_folder
    )

    mod = get_custom_training_algorithm(config_dict['algorithm'], env)

    mod.learn(
        total_timesteps=config_dict['max_steps'],
        callback=checkpoint_callback,
    )


def test_run(config_dict):
    """Evaluate the latest saved checkpoint under several marginal costs.

    The checkpoint is looked up in ``test_logs/<exp_name>`` next to this
    file; evaluation logs go to ``test_logs/test_costs/<exp_name>`` (stored
    in ``config_dict["log_folder"]``). The policy is then evaluated for one
    episode per cost value, with the cost written to ``c_i`` of the
    underlying ``BertrandCompetitionDiscreteEnv``.

    Args:
        config_dict: Experiment configuration. ``'test_type'`` selects the
            environment (``'no_stackelberg'`` or the standard one). The dict
            is mutated: ``'log_folder'`` is set and
            ``'reward_step_random_price_prob'`` is overwritten with 0.

    Returns:
        None.

    Raises:
        FileNotFoundError: If no checkpoint is found in the checkpoint folder.
        ValueError: If ``config_dict['algorithm']`` is not ``'A2C'``, the only
            algorithm this function knows how to load.
        RuntimeError: If the environment stack contains no
            ``BertrandCompetitionDiscreteEnv`` to apply the costs to.
    """
    prefix = "exp_no_stack" if config_dict["test_type"] == 'no_stackelberg' else "exp"

    exp_name = prefix + ".%s.%s.%s.%s.%s.%s.%s.%s.%s.%s.%s.%s" % (
        config_dict['gamma'],
        config_dict['dp_type'],
        config_dict['max_steps'],
        config_dict['algorithm'],
        config_dict['seed'],
        config_dict['tot_num_reward_steps'],
        config_dict['tot_num_eq_steps'],
        config_dict['bbox_state_space_type'],
        config_dict['frac_excluded_eq_steps'],
        config_dict['reward_step_random_price_prob'],
        config_dict['q_restart_rate'],
        config_dict['marginal_cost'],
    )

    log_folder = os.path.join(
        os.path.abspath(os.path.dirname(__file__)) + "/test_logs/test_costs/", exp_name
    )
    config_dict["log_folder"] = log_folder

    chkp_folder = os.path.join(
        os.path.abspath(os.path.dirname(__file__)) + "/test_logs/", exp_name
    )

    # Overwritten only after exp_name was built, so the folder names still
    # carry the originally configured probability.
    config_dict['reward_step_random_price_prob'] = 0

    if config_dict["test_type"] == 'no_stackelberg':
        eval_env = get_no_Stackelberg_env(config_dict)
    else:
        eval_env = get_standard_env(config_dict)
    eval_env = EvaluationAfterJohnsonConvergence(eval_env)

    chkpt_path = get_latest_checkpoint(chkp_folder)
    if chkpt_path is None:
        raise FileNotFoundError("no checkpoint found in %s" % chkp_folder)

    if config_dict['algorithm'] == 'A2C':
        from stable_baselines3 import A2C
        try:
            mod = A2C.load(chkpt_path, env=eval_env)
        except (UnboundLocalError, SystemError):
            # Fallback: retry with the project's custom policy class supplied
            # explicitly.
            import rl_trainer_setup
            mod = A2C.load(chkpt_path, env=eval_env, custom_objects={'policy_class': rl_trainer_setup.CustomPolicy})
    else:
        # Previously this fell through and crashed later on an unbound `mod`.
        raise ValueError(
            "test_run only supports algorithm 'A2C', got %r" % (config_dict['algorithm'],)
        )

    from stable_baselines3.common.evaluation import evaluate_policy

    # Exact type match (not isinstance) is kept from the original selection.
    base_envs = [
        current_env
        for current_env in utils.get_all_wrappers(eval_env)
        if type(current_env) is BertrandCompetitionDiscreteEnv
    ]
    if not base_envs:
        raise RuntimeError(
            "no BertrandCompetitionDiscreteEnv found in the evaluation environment stack"
        )

    # Costs: 1.0 plus the midpoints between consecutive prices at indices
    # 1-2 and 2-3 of the price grid.
    price_space = base_envs[-1].action_price_space
    costs = [
        1.0,
        (price_space[1] + price_space[2]) / 2,
        (price_space[2] + price_space[3]) / 2,
    ]

    for cost in costs:
        for current_env in base_envs:
            current_env.c_i = cost

        # The returned statistics are not used here.
        evaluate_policy(
            mod,
            eval_env,
            n_eval_episodes=1,
        )


def test_run_visualize(config_dict):
    """Save a heat map of the loaded policy's output over all price pairs.

    The latest checkpoint is looked up in ``test_logs/<exp_name>`` next to
    this file. Only a checkpoint at exactly 50030000 steps is visualized;
    otherwise a message is printed and nothing is plotted. For every pair of
    price indices ``(i, j)`` the policy is queried deterministically and the
    number of entries returned by ``eval_env.get_bbx_idx`` is recorded. The
    resulting grid is written as an image into the parent of
    ``config_dict["log_folder"]``.

    Args:
        config_dict: Experiment configuration. ``'test_type'`` selects the
            environment; ``'log_folder'`` is set by this function.

    Returns:
        None.

    Raises:
        FileNotFoundError: If no checkpoint is found in the checkpoint folder.
        ValueError: If ``config_dict['algorithm']`` is not ``'A2C'``, the only
            algorithm this function knows how to load.
    """
    prefix = "exp_no_stack" if config_dict["test_type"] == 'no_stackelberg' else "exp"

    exp_name = prefix + ".%s.%s.%s.%s.%s.%s.%s.%s.%s.%s.%s.%s" % (
        config_dict['gamma'],
        config_dict['dp_type'],
        config_dict['max_steps'],
        config_dict['algorithm'],
        config_dict['seed'],
        config_dict['tot_num_reward_steps'],
        config_dict['tot_num_eq_steps'],
        config_dict['bbox_state_space_type'],
        config_dict['frac_excluded_eq_steps'],
        config_dict['reward_step_random_price_prob'],
        config_dict['q_restart_rate'],
        config_dict['marginal_cost'],
    )

    log_folder = os.path.join(
        os.path.abspath(os.path.dirname(__file__)) + "/test_logs/heatmaps/", exp_name
    )
    config_dict["log_folder"] = log_folder

    chkp_folder = os.path.join(
        os.path.abspath(os.path.dirname(__file__)) + "/test_logs/", exp_name
    )

    if config_dict["test_type"] == 'no_stackelberg':
        eval_env = get_no_Stackelberg_env(config_dict)
    else:
        eval_env = get_standard_env(config_dict)
    eval_env = EvaluationAfterJohnsonConvergence(eval_env)

    chkpt_path = get_latest_checkpoint(chkp_folder)
    if chkpt_path is None:
        raise FileNotFoundError("no checkpoint found in %s" % chkp_folder)

    # The checkpoint file name carries the step count as its third
    # underscore-separated token.
    chkpt_path_spec = chkpt_path.split('/')[-1]
    strp = str(chkpt_path_spec).split("_")

    # Only the checkpoint at this exact step count is plotted.
    required_steps = 50030000
    if int(strp[2]) != required_steps:
        print("maximum policy not obtained: ", chkpt_path )
        return

    if config_dict['algorithm'] == 'A2C':
        from stable_baselines3 import A2C
        try:
            mod = A2C.load(chkpt_path, env=eval_env)
        except (UnboundLocalError, SystemError):
            # Fallback: retry with the project's custom policy class supplied
            # explicitly.
            import rl_trainer_setup
            mod = A2C.load(chkpt_path, env=eval_env, custom_objects={'policy_class': rl_trainer_setup.CustomPolicy})
    else:
        # Previously this fell through and crashed later on an unbound `mod`.
        raise ValueError(
            "test_run_visualize only supports algorithm 'A2C', got %r"
            % (config_dict['algorithm'],)
        )

    # Random initial content is kept from the original; with two agents every
    # cell is overwritten by the loop below.
    qtable_np = np.random.random((eval_env.m,) * eval_env.num_agents)
    for i in range(eval_env.m):
        for j in range(eval_env.m):
            price_i = eval_env.action_price_space[i]
            price_j = eval_env.action_price_space[j]
            prices = np.array([price_i, price_j])
            obs = eval_env.adapt_price_array(np.array([i, j]))
            # A constant 1 is appended to the observation
            # NOTE(review): presumably a flag the policy expects; confirm.
            obs = np.concatenate([obs, np.array([1])])
            temp_action = mod.predict(obs, deterministic=True)[0]
            num_displayed_agents = len(eval_env.get_bbx_idx(prices, temp_action))
            qtable_np[i, j] = num_displayed_agents

    plt.imshow(qtable_np, cmap='hot', interpolation='nearest', origin='lower')
    plt.colorbar()

    # File-name tag: the digits after the decimal point of the probability.
    # Values printed without a '.' (ints such as 0, or exponent notation)
    # used to raise IndexError; they now fall back to their full text.
    prob_text = str(config_dict['reward_step_random_price_prob'])
    whole_part, dot, fraction_part = prob_text.partition('.')
    prob = fraction_part if dot else whole_part

    plt.savefig( config_dict["log_folder"] + "/../" + str(eval_env.bbox_state_space_type) + '_seed_'+ str(config_dict['seed']) + '_prob'+ str(prob) + '_' + str(strp[2]) + '_heat')

    # Closing the figure is sufficient. The former plt.cla() after plt.close()
    # implicitly created a fresh figure that was never closed.
    plt.close()


def johnson_run(config_dict):
    """Run the Johnson environment until it reports completion.

    The experiment name (prefix ``exp_johnson``) is assembled from the
    configuration values; the matching sub-folder of ``logs/`` (next to this
    file) is stored in ``config_dict["log_folder"]`` before the environment
    is created. The environment is then stepped with sampled actions until
    ``done`` is returned.

    Args:
        config_dict: Experiment configuration; ``'dp_type'`` must be one of
            ``'dpdp'``, ``'pdp'`` or ``'no_intervene'``.

    Raises:
        IOError: If ``config_dict['dp_type']`` is not one of the values above.
    """
    if config_dict['dp_type'] not in ('dpdp', 'pdp', 'no_intervene'):
        raise IOError(
            "Buybox needs to be PDP or DPDP or no_intervene to replicate Johnson"
        )

    name_keys = (
        'gamma',
        'dp_type',
        'max_steps',
        'algorithm',
        'seed',
        'tot_num_reward_steps',
        'tot_num_eq_steps',
        'bbox_state_space_type',
        'frac_excluded_eq_steps',
        'reward_step_random_price_prob',
        'q_restart_rate',
        'marginal_cost',
    )
    exp_name = "exp_johnson.%s.%s.%s.%s.%s.%s.%s.%s.%s.%s.%s.%s" % tuple(
        config_dict[key] for key in name_keys
    )

    logs_root = os.path.abspath(os.path.dirname(__file__)) + "/logs/"
    log_folder = os.path.join(logs_root, exp_name)
    config_dict["log_folder"] = log_folder

    env = EvaluationAfterJohnsonConvergence(get_johnson_env(config_dict))

    env.reset()
    done = False

    # NOTE(review): the action is named "dummy", so the environment
    # presumably ignores it; confirm against the env implementation.
    while not done:
        dummy_action = env.action_space.sample()
        obs_sub_env, reward_pricing_agents, done, info = env.step(dummy_action)